{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "# --- standard library ---\n",
    "import json\n",
    "\n",
    "# --- third-party: data handling and evaluation ---\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score\n",
    "\n",
    "# --- third-party: transformer model, attribution and GPU detection ---\n",
    "from simpletransformers.classification import ClassificationArgs, ClassificationModel\n",
    "import torch\n",
    "from transformers_interpret import SequenceClassificationExplainer\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "# NOTE(review): np, train_test_split, the sklearn metrics, ClassificationArgs and\n",
    "# AutoTokenizer are not used by any cell in this notebook.\n",
    "\n",
    "# True when PyTorch can see a CUDA device; echoed as the cell output and\n",
    "# passed to the model loader as use_cuda.\n",
    "cuda_available = torch.cuda.is_available()\n",
    "cuda_available"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Load the BERT classifier fine-tuned on ONQ2 through simpletransformers.\n",
    "# Forward slashes keep this relative path valid on Windows as well as on\n",
    "# Linux/macOS, where a backslash is an ordinary filename character.\n",
    "model_wrapper = ClassificationModel(\n",
    "    \"bert\",\n",
    "    \"02_output data/models/onq2m\",\n",
    "    use_cuda=cuda_available\n",
    ")\n",
    "\n",
    "# Underlying model and tokenizer objects held by the simpletransformers wrapper;\n",
    "# these are what the attribution explainer is built from.\n",
    "model = model_wrapper.model\n",
    "tokenizer = model_wrapper.tokenizer"
   ],
   "id": "97f39e669a203b38",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Wrap the fine-tuned model and its tokenizer in a sequence-classification\n",
    "# attribution explainer (the model and tokenizer themselves are loaded above).\n",
    "cls_explainer = SequenceClassificationExplainer(model=model, tokenizer=tokenizer)"
   ],
   "id": "bee4f95bd194877f",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Load the in-corpus ONQ2 predictions (forward slashes: portable across OSes).\n",
    "df_onq2 = pd.read_excel(\"02_output data/onq2_onq2m_predictions.xlsx\")\n",
    "\n",
    "# Fail early if the columns used by the attribution loop are missing.\n",
    "missing_cols = {\"Text\", \"Predicted Label\"} - set(df_onq2.columns)\n",
    "assert not missing_cols, f\"missing expected columns: {sorted(missing_cols)}\"\n",
    "\n",
    "# Quick look at the data: first rows (rich display) and the frame's shape.\n",
    "display(df_onq2.head(5))\n",
    "df_onq2.shape"
   ],
   "id": "28928324bf9230a6",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Attribution scores for one exemplary response (the first row).\n",
    "# The text is selected by column name, as in the aggregation loop, rather than\n",
    "# by position, so extra or reordered columns cannot feed the wrong value.\n",
    "text = df_onq2[\"Text\"].iloc[0]\n",
    "attributions = cls_explainer(text)\n",
    "\n",
    "# Visualise the token attributions for this response.\n",
    "cls_explainer.visualize()\n",
    "\n",
    "# List of (token, attribution score) tuples.\n",
    "attributions"
   ],
   "id": "192064859026cb91",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Collect attribution scores per token, separately for responses predicted\n",
    "# positive (tokens_1, label 1) and negative (tokens_0, any other label).\n",
    "# Each dict maps token -> list of attribution scores, one entry per occurrence.\n",
    "tokens_1 = {}\n",
    "tokens_0 = {}\n",
    "n_skipped = 0\n",
    "\n",
    "for text, prediction in zip(df_onq2['Text'], df_onq2['Predicted Label']):\n",
    "\n",
    "    # Empty Excel cells are read as NaN, which the explainer cannot tokenize; skip them.\n",
    "    if pd.isna(text):\n",
    "        n_skipped += 1\n",
    "        continue\n",
    "\n",
    "    # Attribution scores for the whole response, as (token, score) tuples.\n",
    "    # str() guards against responses that pandas parsed as numbers.\n",
    "    attributions = cls_explainer(str(text))\n",
    "\n",
    "    # The target dict depends only on the prediction, so pick it once per response.\n",
    "    target = tokens_1 if prediction == 1 else tokens_0\n",
    "    for token, attribution in attributions:\n",
    "        target.setdefault(token, []).append(attribution)\n",
    "\n",
    "if n_skipped:\n",
    "    print(f\"skipped {n_skipped} responses with missing text\")"
   ],
   "id": "ad0177cffff7ee73",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Convert the attribution dictionaries (token -> list of scores) into JSON-formatted\n",
    "# strings first, so a non-serialisable value raises before any file is truncated.\n",
    "json_tokens_1 = json.dumps(tokens_1)\n",
    "json_tokens_0 = json.dumps(tokens_0)\n",
    "\n",
    "# Export the JSON strings; forward slashes keep the paths valid on Windows and POSIX.\n",
    "with open(\"02_output data/onq2_attributions_pos_predictions.json\", \"w\", encoding=\"utf-8\") as outfile:\n",
    "    outfile.write(json_tokens_1)\n",
    "with open(\"02_output data/onq2_attributions_neg_predictions.json\", \"w\", encoding=\"utf-8\") as outfile:\n",
    "    outfile.write(json_tokens_0)"
   ],
   "id": "79f3b1392ad7e088",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Optional: re-load the attribution dictionaries from the exported JSON files,\n",
    "# e.g. to reuse them in a fresh kernel without recomputing attributions.\n",
    "# Forward slashes keep the paths valid on Windows and POSIX alike.\n",
    "with open(\"02_output data/onq2_attributions_pos_predictions.json\", encoding=\"utf-8\") as f:\n",
    "    tokens_1 = json.load(f)\n",
    "with open(\"02_output data/onq2_attributions_neg_predictions.json\", encoding=\"utf-8\") as f:\n",
    "    tokens_0 = json.load(f)"
   ],
   "id": "1d20c0b8843c33e5",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Number of distinct tokens seen for positive (1) and negative (0) predictions.\n",
    "n_tokens_1, n_tokens_0 = len(tokens_1), len(tokens_0)\n",
    "\n",
    "for n_distinct in (n_tokens_1, n_tokens_0):\n",
    "    print(n_distinct)"
   ],
   "id": "cc82e656a5ad4f87",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Frequency floor: only tokens occurring more than MIN_FREQ times are ranked,\n",
    "# keeping tokens whose average rests on very few scores out of the ranking.\n",
    "MIN_FREQ = 25\n",
    "\n",
    "def rank_tokens(token_scores, min_freq=MIN_FREQ):\n",
    "    \"\"\"Rank tokens by their mean attribution score.\n",
    "\n",
    "    token_scores: dict mapping token -> list of attribution scores.\n",
    "    Returns a list of (token, mean score, frequency) tuples for tokens whose\n",
    "    frequency exceeds min_freq, sorted by mean score in descending order\n",
    "    (ties keep the dictionary's insertion order, since sorted() is stable).\n",
    "    \"\"\"\n",
    "    summary = [(token, sum(scores) / len(scores), len(scores)) for token, scores in token_scores.items()]\n",
    "    frequent = [row for row in summary if row[2] > min_freq]\n",
    "    return sorted(frequent, key=lambda row: row[1], reverse=True)\n",
    "\n",
    "sorted_filtered_freq_score_tokens_1 = rank_tokens(tokens_1)\n",
    "sorted_filtered_freq_score_tokens_0 = rank_tokens(tokens_0)\n",
    "\n",
    "# Frequent tokens ordered by mean attribution, for positive (1) and negative (0) predictions.\n",
    "print(sorted_filtered_freq_score_tokens_1)\n",
    "print(sorted_filtered_freq_score_tokens_0)"
   ],
   "id": "2d993b7e8ef5c25f",
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
